{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Evaluation Study"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Setup"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import os, sys\r\n",
    "import shutil\r\n",
    "import math\r\n",
    "import time\r\n",
    "from glob import glob\r\n",
    "import numpy as np\r\n",
    "import pandas as pd\r\n",
    "import matplotlib.pyplot as plt\r\n",
    "import seaborn as sns\r\n",
    "\r\n",
    "import torch\r\n",
    "import torch.nn as nn\r\n",
    "import torch.nn.functional as F\r\n",
    "\r\n",
    "sys.path.append('..')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "source": [
    "# PATH Setting\r\n",
    "# The data root defaults to the original local drive. Set the EVAL_DIR_PATH\r\n",
    "# environment variable to point the notebook at another location without\r\n",
    "# editing this cell.\r\n",
    "DIR_PATH = os.environ.get('EVAL_DIR_PATH', 'F:/20210730_samples')\r\n",
    "\r\n",
    "# Sub-folders of the data root; the NCC cells below read their files with np.load.\r\n",
    "SAVE_PATH = os.path.join(DIR_PATH, 'numpy')\r\n",
    "CT_PATH = os.path.join(SAVE_PATH, 'CT_target')\r\n",
    "PT_PATH = os.path.join(SAVE_PATH, 'PT_vanilla')\r\n",
    "PT_PATH_2 = os.path.join(SAVE_PATH, 'PT_resize')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "# Use the GPU when one is available; otherwise fall back to the CPU so the\r\n",
    "# later .to(device) calls do not fail on machines without CUDA.\r\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## NCC"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "source": [
    "from utils.common.metrics import NCC\r\n",
    "\r\n",
    "# os.listdir() returns entries in arbitrary order, so list each folder once\r\n",
    "# and sort it: the i-th CT file is then paired with the i-th PT file\r\n",
    "# deterministically, and the directory is not re-read on every iteration.\r\n",
    "# NOTE(review): assumes corresponding files sort to the same position in\r\n",
    "# both folders - confirm against the file naming scheme.\r\n",
    "ct_files = sorted(os.listdir(CT_PATH))\r\n",
    "pt_files = sorted(os.listdir(PT_PATH))\r\n",
    "\r\n",
    "# Pre-allocate 16 slots of shape (1, 128, 128, 256); each loaded array\r\n",
    "# fills index [i, 0] of its batch.\r\n",
    "Ii = np.zeros([16, 1, 128, 128, 256])\r\n",
    "Ji = np.zeros([16, 1, 128, 128, 256])\r\n",
    "for i in range(16):\r\n",
    "    Ii[i,0] = np.load(os.path.join(CT_PATH, ct_files[i]))\r\n",
    "    Ji[i,0] = np.load(os.path.join(PT_PATH, pt_files[i]))\r\n",
    "\r\n",
    "# Convert to float32 tensors on the selected device.\r\n",
    "Ii = torch.FloatTensor(Ii).to(device)\r\n",
    "Ji = torch.FloatTensor(Ji).to(device)\r\n",
    "\r\n",
    "# Score the CT batch against the PT_vanilla batch; the result is moved back\r\n",
    "# to the CPU for printing and averaging.\r\n",
    "criterion = NCC(device)\r\n",
    "cc = criterion(Ii, Ji).cpu()\r\n",
    "print ('Autograd Check', cc.requires_grad)\r\n",
    "print ('NCC', cc)\r\n",
    "print ('Average NCC', cc.mean().item())"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Autograd Check False\n",
      "NCC tensor([0.2473, 0.2616, 0.3115, 0.2672, 0.2100, 0.2890, 0.2440, 0.1606, 0.1522,\n",
      "        0.1706, 0.2456, 0.1929, 0.2322, 0.1960, 0.3376, 0.2225])\n",
      "Average NCC 0.23380093276500702\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "source": [
    "# Same evaluation as the previous cell, but against the PT_resize folder.\r\n",
    "# NCC is imported in the previous cell.\r\n",
    "#\r\n",
    "# os.listdir() returns entries in arbitrary order, so list each folder once\r\n",
    "# and sort it: the i-th CT file is then paired with the i-th PT file\r\n",
    "# deterministically, and the directory is not re-read on every iteration.\r\n",
    "# NOTE(review): assumes corresponding files sort to the same position in\r\n",
    "# both folders - confirm against the file naming scheme.\r\n",
    "ct_files = sorted(os.listdir(CT_PATH))\r\n",
    "pt_resize_files = sorted(os.listdir(PT_PATH_2))\r\n",
    "\r\n",
    "# Pre-allocate 16 slots of shape (1, 128, 128, 256); each loaded array\r\n",
    "# fills index [i, 0] of its batch.\r\n",
    "Ii = np.zeros([16, 1, 128, 128, 256])\r\n",
    "Ji = np.zeros([16, 1, 128, 128, 256])\r\n",
    "for i in range(16):\r\n",
    "    Ii[i,0] = np.load(os.path.join(CT_PATH, ct_files[i]))\r\n",
    "    Ji[i,0] = np.load(os.path.join(PT_PATH_2, pt_resize_files[i]))\r\n",
    "\r\n",
    "# Convert to float32 tensors on the selected device.\r\n",
    "Ii = torch.FloatTensor(Ii).to(device)\r\n",
    "Ji = torch.FloatTensor(Ji).to(device)\r\n",
    "\r\n",
    "# Score the CT batch against the PT_resize batch; the result is moved back\r\n",
    "# to the CPU for printing and averaging.\r\n",
    "criterion = NCC(device)\r\n",
    "cc = criterion(Ii, Ji).cpu()\r\n",
    "print ('Autograd Check', cc.requires_grad)\r\n",
    "print ('NCC', cc)\r\n",
    "print ('Average NCC', cc.mean().item())"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Autograd Check False\n",
      "NCC tensor([0.4268, 0.4540, 0.5476, 0.4465, 0.3494, 0.5016, 0.3968, 0.2531, 0.2562,\n",
      "        0.2873, 0.4179, 0.3197, 0.3976, 0.3193, 0.5823, 0.3773])\n",
      "Average NCC 0.3958446681499481\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.8.10",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.8.10 64-bit ('torch': conda)"
  },
  "interpreter": {
   "hash": "acce06b49c94f1480d6840cdb9b2e7c4d2e90f84d9ec42d5eadc9fe035ca1697"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}